#include "stdafx.h"
#include "GOST_Cipher.h"
#include <cstddef>
// Substitution boxes used by the round function f(): row i is applied to the
// i-th SUB-bit group of (a + k), counting from the least-significant end, and
// each row maps an index 0..INSUB-1 to a value 0..15.
// NOTE(review): the values appear to match the widely published GOST "test"
// S-boxes (e.g. the GOST R 34.11-94 test parameter set) — verify before
// relying on interoperability with other implementations.
const uint GOST_Cipher::S[GOST_Cipher::EIGHT][GOST_Cipher::INSUB]=
{
	{ 4,10, 9, 2,13, 8, 0,14, 6,11, 1,12, 7,15, 5, 3},
	{14,11, 4,12, 6,13,15,10, 2, 3, 8, 1, 0, 7, 5, 9},
	{ 5, 8, 1,13,10, 3, 4, 2,14,15,12, 7, 6, 0, 9,11},
	{ 7,13,10, 1, 0, 8, 9,15,14, 4, 6,12,11, 2, 5, 3},
	{ 6,12, 7, 1, 5,15,13, 8, 4,10, 9,14, 0, 3,11, 2},
	{ 4,11,10, 0, 7, 2, 1,13, 3, 6, 8, 5, 9,12,15,14},
	{13,11, 4, 1, 3,15, 5, 9, 0,10,14, 7, 6, 8, 2,12},
	{ 1,15,13, 0, 5, 7,10, 4, 9, 2, 3,14, 6,11, 8,12}
};
void GOST_Cipher::init(const uint key[])
{
	for(uint i=0;i<KEY_LENGTH;++i)
		KEY[i]=key[i];
	plain=new text;
	cipher=new text;
}
// Builds a cipher keyed with `key` (KEY_LENGTH words) and a
// value-initialized IV.
GOST_Cipher::GOST_Cipher(const uint key[])
	: iv()
{
	init(key);
}
// Builds a cipher keyed with `key` (KEY_LENGTH words) whose IV is a copy of
// `in`; the IV seeds the feedback/counter register used by Encipher and
// Decipher and the initial state of Get_Hash.
GOST_Cipher::GOST_Cipher(const uint key[KEY_LENGTH], const init_vector in)
	: iv(in)
{
	init(key);
}

//==========================CIPHER==========================//
/// GOST round function.
///
/// Adds the round key k to the half-block a (uint wrap-around arithmetic),
/// replaces each of the EIGHT SUB-bit groups of the sum with the output of
/// its S-box row (row i handles group i, counted from the low end), and
/// rotates the assembled word left by SHIFT bits.
uint GOST_Cipher::f(const uint a, const uint k) const
{
	const uint mixed=a+k;
	uint substituted=0;
	for(uint group=0;group<EIGHT;++group)
	{
		// Shift group `group` up to the top of the word, then back down to the
		// bottom so every other bit falls off (assumes BITS_IN_UINT is the
		// width of uint).
		const uint leftShift=BITS_IN_UINT-SUB*(group+1);
		const uint index=(mixed<<leftShift)>>(BITS_IN_UINT-SUB);
		const uint boxed=S[group][index];
		substituted+=boxed<<SUB*group;
	}
	const uint high=substituted<<SHIFT;
	const uint low=substituted>>(BITS_IN_UINT-SHIFT);
	return high+low;
}

/// One GOST round applied to x with the round key chosen for round `iter`.
///
/// Key order:
///   ENCIPHER - KEY[0..KEY_LENGTH-1] cyclically for the first
///              ROUNDS-KEY_LENGTH rounds, then KEY_LENGTH-1..0.
///   DECIPHER - KEY[0..KEY_LENGTH-1] once, then KEY_LENGTH-1..0 cyclically.
/// Output halves: l' = r ^ f(l, K), r' = l.
/// A round number >= ROUNDS returns x unchanged.
block GOST_Cipher::feistel(const block x,ubyte iter, DIRECTION dir)const
{
	if(iter>=ROUNDS)
		return x;

	const ubyte forward=iter%KEY_LENGTH;        // position within an 8-word pass
	const ubyte backward=KEY_LENGTH-forward-1;  // same position, key read in reverse
	ubyte keyIndex=iter;                         // any other direction: iter as-is
	if(dir==ENCIPHER)
		keyIndex=(iter<ROUNDS-KEY_LENGTH) ? forward : backward;
	else if(dir==DECIPHER)
		keyIndex=(iter<KEY_LENGTH) ? forward : backward;

	block out;
	out.l=x.r^f(x.l,KEY[keyIndex]);
	out.r=x.l;
	return out;
}
/// Enciphers one block with the current KEY: ROUNDS encryption-order
/// feistel rounds, then block::swap() on the result (presumably cancelling
/// the half exchange of the final round — block::swap is defined elsewhere).
block GOST_Cipher::E(const block x)
{
	block state=x;
	uint round=0;
	while(round<ROUNDS)
		state=feistel(state,round++,ENCIPHER);
	state.swap();
	return state;
}
/// Deciphers one block with the current KEY: ROUNDS decryption-order
/// feistel rounds (reversed key schedule), then block::swap() on the result.
block GOST_Cipher::D(const block x)
{
	block state=x;
	uint round=0;
	while(round<ROUNDS)
		state=feistel(state,round++,DECIPHER);
	state.swap();
	return state;
}
/// Enciphers every block of `plain` in the requested mode.
///
/// The result is written into this object's own output buffer (the member
/// pointer `cipher`, allocated in init()), which is cleared first. That same
/// pointer is returned, so its contents are overwritten by the next call.
/// Note: the parameter `plain` shadows the member of the same name.
///
/// Modes, with the register s starting at iv.val[0]:
///   ECB  - y = E(x)
///   CFB  - y = x ^ E(s), then s = y
///   CBC  - y = x ^ E(s), then s is advanced with block's postfix ++.
///          NOTE(review): this is a counter/gamma construction, not classic
///          chaining; Decipher mirrors it, so it is kept as-is for
///          compatibility with existing ciphertexts.
///   TEMK - y = temk_E(x)
///   FE   - y = fe_E(x)
/// NOTE(review): an unlisted mode pushes a default-constructed block.
text *GOST_Cipher::Encipher(const text &plain, MODE mode)
{
	// The caller may feed our own output buffer back in, e.g.
	// Encipher(*Encipher(t, m), m). Clearing `cipher` below would then empty
	// the input before it is read, and the call would silently return
	// nothing. Work from a private copy in that case.
	if(&plain==cipher)
	{
		const text input(plain);
		return Encipher(input,mode);
	}

	cipher->clear();
	block s=iv.val[0];
	for(ptext i=plain.begin();i<plain.end();++i)
	{
		block x=*i;
		block y;
		switch(mode)
		{
		case ECB:
			y=E(x);
			break;
		case CFB:
			y=x^E(s);
			s=y;
			break;
		case CBC:
			y=x^E(s++);
			break;
		case TEMK:
			y=temk_E(x);
			break;
		case FE:
			y=fe_E(x);
			break;
		}
		cipher->push_back(y);
	}
	return cipher;
}
/// Deciphers every block of `cipher` in the requested mode (the inverse of
/// Encipher for the same key, IV and mode).
///
/// The result is written into this object's own output buffer (the member
/// pointer `plain`, allocated in init()), which is cleared first. That same
/// pointer is returned, so its contents are overwritten by the next call.
/// Note: the parameter `cipher` shadows the member of the same name.
///
/// Modes, with the register s starting at iv.val[0]:
///   ECB  - x = D(y)
///   CFB  - x = y ^ E(s), then s = y
///   CBC  - x = y ^ E(s), then s is advanced with block's postfix ++.
///          NOTE(review): a counter/gamma construction rather than classic
///          chaining; it mirrors Encipher and is kept for compatibility.
///   TEMK - x = temk_D(y)
///   FE   - x = fe_D(y)
/// NOTE(review): an unlisted mode pushes a default-constructed block.
text *GOST_Cipher::Decipher(const text &cipher, MODE mode)
{
	// The caller may feed our own output buffer back in, e.g.
	// Decipher(*Decipher(t, m), m). Clearing `plain` below would then empty
	// the input before it is read, and the call would silently return
	// nothing. Work from a private copy in that case.
	if(&cipher==plain)
	{
		const text input(cipher);
		return Decipher(input,mode);
	}

	plain->clear();
	block s=iv.val[0];
	for(ptext i=cipher.begin();i<cipher.end();++i)
	{
		block y=*i;
		block x;
		switch(mode)
		{
		case ECB:
			x=D(y);
			break;
		case CFB:
			x=y^E(s);
			s=y;
			break;
		case CBC:
			x=y^E(s++);
			break;
		case TEMK:
			x=temk_D(y);
			break;
		case FE:
			x=fe_D(y);
			break;
		}
		plain->push_back(x);
	}
	return plain;
}
void GOST_Cipher::set_key(const uint key[KEY_LENGTH])
{
	for(uint i=0;i<KEY_LENGTH;++i)
		KEY[i]=key[i];
}
void GOST_Cipher::set_key(const init_vector key)
{
	for(uint i=0;i<IV_X;++i)
	{KEY[2*i]=key.val[i].l;KEY[2*i+1]=key.val[i].r;}
}
/// Computes an IV_X-block digest of `plain`.
///
/// The text is consumed in chunks of IV_X blocks, and a final partial chunk
/// is zero-padded. The state h starts at the object's iv. For each chunk p,
/// h is installed as the key (set_key(h)) and every block is updated as
///     h[j] = E(h[j] ^ p[j]) ^ h[j]
/// KEY is saved on entry and restored before returning.
///
/// NOTE(review): the padding carries no length information, so texts that
/// differ only by trailing zero blocks inside the last chunk hash the same.
/// NOTE(review): only KEY_LENGTH words of KEY are saved and restored, while
/// set_key(h) writes 2*IV_X words; confirm 2*IV_X <= KEY_LENGTH.
init_vector GOST_Cipher::Get_Hash(const text &plain)
{
	// Save the caller's key: the loop below re-keys the cipher with h.
	uint old_key[KEY_LENGTH];
	for(ubyte k=0;k<KEY_LENGTH;++k) old_key[k]=KEY[k];

	init_vector h=iv,p;
	for(ptext i=plain.begin();i<plain.end();i+=IV_X)
	{
		// Blocks still unread. Keep the full iterator-difference width: a
		// ubyte counter wraps modulo 256, so e.g. a 256-block text looked like
		// "0 blocks left" and was hashed as a single all-zero chunk (every
		// such text produced the same digest).
		const std::ptrdiff_t remaining=plain.end()-i;
		const bool eof=remaining<static_cast<std::ptrdiff_t>(IV_X);
		if(eof)
		{
			// Final partial chunk: copy what is left and zero-pad the rest.
			const ubyte tail=static_cast<ubyte>(remaining);
			for(ubyte j=0;j<tail;++j) p.val[j]=*(i+j);
			for(ubyte j=tail;j<IV_X;++j) p.val[j]=block(0,0);
		}
		else
			for(ubyte j=0;j<IV_X;++j) p.val[j]=*(i+j);

		//encipher with key h
		set_key(h);
		//h(i)=E(h(i-1)+p(i))+h(i-1)
		for(ubyte j=0;j<IV_X;++j)
			h.val[j] = E(h.val[j]^p.val[j])^h.val[j];

		// Stop here; i+=IV_X would step past plain.end().
		if(eof) break;
	}
	// Restore the caller's key.
	for(ubyte k=0;k<KEY_LENGTH;++k) KEY[k]=old_key[k];
	return h;
}



// Fixed constant blocks that temk_E/temk_D pass through the master key to
// derive their three sub-keys.
const block T1(1,2);
const block T2(2,3);
const block T3(3,4);
/// TEMK encryption of a single block.
///
/// Three sub-keys are derived from the master key by running T1, T2 and T3
/// through E (key as stored, "X1 X2"), then D (the two SINGLE_KEY_LENGTH-word
/// halves of KEY exchanged, "X2 X1"), then E again (halves restored). Each
/// derived block (l, r) becomes words 6 and 7 of an otherwise all-zero
/// 8-word key, and x is transformed as E3(D2(E1(x))).
/// NOTE(review): the derivation is redone for every block and builds three
/// temporary GOST_Cipher objects (each allocating two text buffers in init());
/// consider caching the sub-keys if this is on a hot path.
block GOST_Cipher::temk_E(block x)
{
	block derived[3]={E(T1),E(T2),E(T3)};

	// Exchange key halves -> "X2 X1".
	for(ubyte w=0;w<SINGLE_KEY_LENGTH;++w)
	{
		const uint upper=KEY[w+SINGLE_KEY_LENGTH];
		KEY[w+SINGLE_KEY_LENGTH]=KEY[w];
		KEY[w]=upper;
	}
	for(ubyte n=0;n<3;++n) derived[n]=D(derived[n]);

	// Exchange them back -> "X1 X2".
	for(ubyte w=0;w<SINGLE_KEY_LENGTH;++w)
	{
		const uint upper=KEY[w+SINGLE_KEY_LENGTH];
		KEY[w+SINGLE_KEY_LENGTH]=KEY[w];
		KEY[w]=upper;
	}
	for(ubyte n=0;n<3;++n) derived[n]=E(derived[n]);

	uint key1[]={0,0,0,0,0,0,derived[0].l,derived[0].r};
	uint key2[]={0,0,0,0,0,0,derived[1].l,derived[1].r};
	uint key3[]={0,0,0,0,0,0,derived[2].l,derived[2].r};
	GOST_Cipher first(key1),second(key2),third(key3);

	const block stage1=first.E(x);
	const block stage2=second.D(stage1);
	return third.E(stage2);
}
/// TEMK decryption of a single block; the inverse of temk_E.
///
/// Derives the same three sub-keys as temk_E (T1..T3 through E under
/// "X1 X2", D under the half-exchanged key "X2 X1", then E under "X1 X2";
/// each result (l, r) becomes words 6 and 7 of an otherwise all-zero 8-word
/// key), then transforms x as D1(E2(D3(x))).
/// NOTE(review): the derivation is redone for every block and builds three
/// temporary GOST_Cipher objects (each allocating two text buffers in init()).
block GOST_Cipher::temk_D(block x)
{
	block derived[3]={E(T1),E(T2),E(T3)};

	// Exchange key halves -> "X2 X1".
	for(ubyte w=0;w<SINGLE_KEY_LENGTH;++w)
	{
		const uint upper=KEY[w+SINGLE_KEY_LENGTH];
		KEY[w+SINGLE_KEY_LENGTH]=KEY[w];
		KEY[w]=upper;
	}
	for(ubyte n=0;n<3;++n) derived[n]=D(derived[n]);

	// Exchange them back -> "X1 X2".
	for(ubyte w=0;w<SINGLE_KEY_LENGTH;++w)
	{
		const uint upper=KEY[w+SINGLE_KEY_LENGTH];
		KEY[w+SINGLE_KEY_LENGTH]=KEY[w];
		KEY[w]=upper;
	}
	for(ubyte n=0;n<3;++n) derived[n]=E(derived[n]);

	uint key1[]={0,0,0,0,0,0,derived[0].l,derived[0].r};
	uint key2[]={0,0,0,0,0,0,derived[1].l,derived[1].r};
	uint key3[]={0,0,0,0,0,0,derived[2].l,derived[2].r};
	GOST_Cipher first(key1),second(key2),third(key3);

	const block stage1=third.D(x);
	const block stage2=second.E(stage1);
	return first.D(stage2);
}
/////////////////////////////////
/// FE encryption of a single block using three sub-keys taken from
/// consecutive SINGLE_KEY_LENGTH-word slices of KEY:
/// x -> E1 -> D2 -> E3 -> D2 -> E1.
/// NOTE(review): each slice holds SINGLE_KEY_LENGTH words, but the
/// GOST_Cipher(const uint[]) constructor reads KEY_LENGTH words through
/// init(). This stays in bounds only if SINGLE_KEY_LENGTH >= KEY_LENGTH;
/// confirm in GOST_Cipher.h.
block GOST_Cipher::fe_E(block x)
{
	uint slice[3][SINGLE_KEY_LENGTH];
	for(ubyte part=0;part<3;++part)
		for(ubyte w=0;w<SINGLE_KEY_LENGTH;++w)
			slice[part][w]=KEY[part*SINGLE_KEY_LENGTH+w];
	GOST_Cipher c1(slice[0]),c2(slice[1]),c3(slice[2]);

	block t=c1.E(x);
	t=c2.D(t);
	t=c3.E(t);
	t=c2.D(t);
	return c1.E(t);
}
/// FE decryption of a single block (the inverse of fe_E) using three
/// sub-keys taken from consecutive SINGLE_KEY_LENGTH-word slices of KEY:
/// x -> D1 -> E2 -> D3 -> E2 -> D1.
/// NOTE(review): each slice holds SINGLE_KEY_LENGTH words, but the
/// GOST_Cipher(const uint[]) constructor reads KEY_LENGTH words through
/// init(). This stays in bounds only if SINGLE_KEY_LENGTH >= KEY_LENGTH;
/// confirm in GOST_Cipher.h.
block GOST_Cipher::fe_D(block x)
{
	uint slice[3][SINGLE_KEY_LENGTH];
	for(ubyte part=0;part<3;++part)
		for(ubyte w=0;w<SINGLE_KEY_LENGTH;++w)
			slice[part][w]=KEY[part*SINGLE_KEY_LENGTH+w];
	GOST_Cipher c1(slice[0]),c2(slice[1]),c3(slice[2]);

	block t=c1.D(x);
	t=c2.E(t);
	t=c3.D(t);
	t=c2.E(t);
	return c1.D(t);
}